{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explaining a Question Answering Transformers Model\n",
"\n",
"Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import transformers\n",
"\n",
"import shap\n",
"\n",
"# load the model\n",
"pmodel = transformers.pipeline(\"question-answering\")\n",
"tokenized_qs = None # variable to store the tokenized data\n",
"\n",
"\n",
"# define two predictions, one that outputs the logits for the range start,\n",
"# and the other for the range end\n",
"def f(questions, tokenized_qs, start):\n",
" outs = []\n",
" for q in questions:\n",
" idx = np.argwhere(np.array(tokenized_qs[\"input_ids\"]) == pmodel.tokenizer.sep_token_id)[\n",
" 0, 0\n",
" ] # this code assumes that there is only one sentence in data\n",
" d = tokenized_qs.copy()\n",
" d[\"input_ids\"][:idx] = q[:idx]\n",
" d[\"input_ids\"][idx + 1 :] = q[idx + 1 :]\n",
" out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})\n",
" logits = out.start_logits if start else out.end_logits\n",
" outs.append(logits.reshape(-1).detach().numpy())\n",
" return outs\n",
"\n",
"\n",
"def tokenize_data(data):\n",
" for q in data:\n",
" question, context = q.split(\"[SEP]\")\n",
" tokenized_data = pmodel.tokenizer(question, context)\n",
" return tokenized_data # this code assumes that there is only one sentence in data\n",
"\n",
"\n",
"def f_start(questions):\n",
" return f(questions, tokenized_qs, True)\n",
"\n",
"\n",
"def f_end(questions):\n",
" return f(questions, tokenized_qs, False)\n",
"\n",
"\n",
"# attach a dynamic output_names property to the models so we can plot the tokens at each output position\n",
"def out_names(inputs):\n",
" question, context = inputs.split(\"[SEP]\")\n",
" d = pmodel.tokenizer(question, context)\n",
" return [pmodel.tokenizer.decode([id]) for id in d[\"input_ids\"]]\n",
"\n",
"\n",
"f_start.output_names = out_names\n",
"f_end.output_names = out_names"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain the starting positions\n",
"\n",
"Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, is is important that we pass the model's native tokenizer for masking, so that when we hide portions of the text we can retain the same number of tokens and hence the same meaning for each output position."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b4f99a75cbc4f3984b0bfb65fbaaff9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/498 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Partition explainer: 2it [00:32, 32.86s/it] \n"
]
},
{
"data": {
"text/html": [
"\n",
"
\n",
"